#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Apr 19 15:54:38 2022

@author: qiguangyao
"""


#%%Lib
import copy
import numpy as np
import pickle as pkl
import matplotlib.pyplot as plt
from scipy import stats
import seaborn as sns 
from scipy import asarray as ar,exp
from scipy.optimize import curve_fit
import math
import scipy.io as sio
import pingouin as pg
from sklearn import linear_model
from pylab import cos
import pandas as pd
import random
from statsmodels.formula.api import ols
from statsmodels.stats.anova import anova_lm
from statsmodels.sandbox.stats.multicomp import multipletests # for multiple comparisons correction
from statsmodels.stats.multicomp import pairwise_tukeyhsd
# Announce which script is running. `__file__` is not defined in every
# interactive console, so fall back to a placeholder instead of raising
# NameError when this cell is executed outside a normal script run.
print("__file Output:",globals().get("__file__","<interactive>"))
#%%functions
import scipy.stats
def mean_confidence_interval(data, confidence=0.95):
    """
    Mean of `data` together with a two-sided Student-t confidence interval.

    data: sequence/array of samples (converted with np.array)
    confidence: coverage of the interval, 0.95 by default

    returns: (mean, lower bound, upper bound)
    """
    samples = np.array(data) * 1.0
    dof = len(samples) - 1
    center = np.mean(samples)
    # half-width = standard error of the mean * t critical value (n-1 dof)
    half_width = scipy.stats.sem(samples) * scipy.stats.t.ppf((1 + confidence) / 2., dof)
    return center, center - half_width, center + half_width

def adjust_spines(ax, spines):
    """
    Keep only the requested spines of `ax` and move them outward.

    ax: axes-like object exposing `spines` (mapping), `xaxis` and `yaxis`
    spines: container of spine names to keep (e.g. ['left', 'bottom'])

    Spines not listed are hidden by setting their color to 'none'. Ticks are
    removed from an axis whose 'left' / 'bottom' spine is not kept.
    """
    for name, spine in ax.spines.items():
        if name not in spines:
            spine.set_color('none')  # hide this spine
            continue
        spine.set_position(('outward', 10))  # shift outward by 10 points

    # y ticks only make sense when the left spine is shown
    if 'left' not in spines:
        ax.yaxis.set_ticks([])
    else:
        ax.yaxis.set_ticks_position('left')

    # x ticks only make sense when the bottom spine is shown
    if 'bottom' not in spines:
        ax.xaxis.set_ticks([])
    else:
        ax.xaxis.set_ticks_position('bottom')
        
def gaus(x,a,x0,sigma):
    """
    Normal density centred at `x0` with standard deviation `sigma`, scaled by `a`.

    x: evaluation point(s), scalar or numpy array
    a: area under the curve (multiplies the unit-area density)
    x0: centre of the peak
    sigma: standard deviation (must be non-zero)

    returns: a / (sigma * sqrt(2*pi)) * exp(-(x - x0)**2 / (2 * sigma**2))
    """
    # The previous expression `1/sigma*np.sqrt(2*np.pi)` evaluated to
    # sqrt(2*pi)/sigma because of operator precedence; the normalisation
    # constant of a Gaussian is 1/(sigma*sqrt(2*pi)).
    norm = 1.0 / (sigma * np.sqrt(2 * np.pi))
    return a * norm * np.exp(-(x - x0) ** 2 / (2 * sigma ** 2))

def gaussian(X, amp, cen, wid):
    """
    Un-normalised Gaussian bump: amp * exp(-(X - cen)**2 / wid).

    X: evaluation point(s), scalar or numpy array
    amp: peak height (value at X == cen)
    cen: centre of the peak
    wid: width term dividing the squared distance (used as-is, not squared)
    """
    offset = X - cen
    return amp * np.exp(-(offset ** 2) / wid)

def getPossionPDF(mu,x):
    """
    Poisson probability of observing count `x` for rate `mu`.

    mu: rate; 0.01 is added before evaluation (presumably to keep the rate
        away from zero - confirm with the caller)
    x: scalar count; clamped to [0, 170] and rounded to the nearest integer

    returns: exp(-mu) * mu**x / x!  (never negative)

    The probability is evaluated in log space so that large rates no longer
    overflow: the direct form raised OverflowError when `mu**x` exceeded the
    float range (e.g. mu=100, x=170).
    """
    # clamp the count exactly as before, then force a plain int
    if x > 170:
        x = 170
    if x < 0:
        x = 0
    count = int(round(x))
    rate = mu + 0.01
    if rate <= 0:
        # log(rate) is undefined here; keep the legacy direct evaluation
        out = math.exp(-rate) * (rate ** count) / math.factorial(count)
        return 0 if out < 0 else out
    # log P = x*log(mu) - mu - log(x!)
    return math.exp(count * math.log(rate) - rate - math.lgamma(count + 1))

#tuning curve fitting
def vonMisesFunction(x,b,a,u):
    """
    Rectified cosine tuning curve: max(b + a*cos(x - u), 0).

    x: angle(s) in radians, scalar or numpy array
    b: baseline offset
    a: modulation amplitude
    u: preferred angle (phase shift)

    returns: numpy array of the same shape as x with negative values set to 0
    """
    rate = np.array(b + a * np.cos(x - u))
    # negative values are replaced by zero
    return np.where(rate < 0, 0, rate)

def getvonMisesParas(x,y):
    """
    Fit vonMisesFunction to (x, y) with scipy's curve_fit.

    x:hand position
    y:firing rate

    returns: fitted parameters in the order [b, a, u]
    """
    start = [1, 0, 1]  # initial guess for [b,a,u]
    fitted, _ = curve_fit(vonMisesFunction, x, y, p0=start, maxfev=500000)
    return fitted

def getExpParas(x,y):
    """
    Fit expFunction (a * exp(-b * x) + c) to (x, y) with scipy's curve_fit.

    x:independent variable
    y:observed values

    returns: fitted parameters in the order [a, b, c]
    """
    start = [1, 0, 1]  # initial guess for [a,b,c]
    fitted, _ = curve_fit(expFunction, x, y, p0=start, maxfev=500000)
    return fitted

def expFunction(x, a, b, c):
    """
    Exponential with offset: a * exp(-b * x) + c.

    x: evaluation point(s), scalar or numpy array
    a: amplitude of the exponential term
    b: rate in the exponent (positive b means decay)
    c: constant offset
    """
    decay = np.exp(-b * x)
    return c + a * decay

#%% ------------figure 4---------------- 
#%import data
#4C, D
# Load the figure-4 data bundle. The file is opened in a `with` block so the
# handle is closed as soon as unpickling finishes.
# NOTE: unpickling can execute arbitrary code - only load trusted files.
with open('fig4Data.pickle','rb') as dataFile:
    fig4Data = pkl.load(dataFile)

# dPCA results for each area: explained variance and the full projections
dPCAPMCTrial5 = fig4Data['dPCAPMCTrial5']
dPCAArea5Trial5 = fig4Data['dPCAArea5Trial5']
explVarPMC = dPCAPMCTrial5['explVar']
ZfullPMC = dPCAPMCTrial5['Zfull']
explVarArea5 = dPCAArea5Trial5['explVar']
ZfullArea5 = dPCAArea5Trial5['Zfull']

#4E
# per-area score tables used by the decoding-accuracy plot
corrRateRelaDrifHandRaw = fig4Data['corrRateRelaDrifHandRaw']
corrRateScoreBinSmooPMC = corrRateRelaDrifHandRaw['corrRateScoreBinSmooPMC']
corrRateScoreBinSmooArea5 = corrRateRelaDrifHandRaw['corrRateScoreBinSmooArea5']

#4F
# jPECC matrix (heatmap) and its asymmetry index (line plot)
jPECCPMC2Area5 = fig4Data['jPECCPMC2Area5']
asymJPECCIndePMC2Area5 = fig4Data['asymJPECCIndePMC2Area5']

#%%fig4C
# Pie chart of the marginalized variance fractions for one area.
explVar = copy.deepcopy(explVarPMC)  # Premotor
# explVar = copy.deepcopy(explVarArea5)  # Parietal

# ratio of marginalized to total variance, unwrapped from its nested
# container and rescaled so the slices sum to 1
d = explVar['totalMarginalizedVar'] / explVar['totalVar'] * 100
d = d[0][0][0]
d = d / sum(d)

labels = ['Target', 'Condition', 'Time']  # defined but not passed to pie()
textprops = {"fontsize": 4}
explode = (0.05,) * 3  # pull every slice slightly away from the centre

with plt.style.context('style_paper.mplstyle'):
    fig1, ax1 = plt.subplots(figsize=(3.54 / 4, 3.54 / 4))
    wedges, texts, autotexts = ax1.pie(
        d,
        explode=explode,
        autopct='%1.1f%%',
        shadow=False,
        startangle=90,
        textprops=textprops,
        colors=['#1f77b4', '#d62728', '#7f7f7f'],
    )
    # make every wedge fully opaque
    for p in wedges:
        p.set_alpha(1)
    plt.rc('font', size=8)
    plt.rc('legend', fontsize=8)
    plt.tight_layout()
    fileName = 'fig4C_dPCAExpVarPie' + 'Premotor' + '.pdf'
    # plt.savefig(fileName,dpi = 600)
plt.show()
#%%
# Bar chart of the per-component variance for the first 15 components,
# one bar series per row of explVar['margVar'].
with plt.style.context('style_paper.mplstyle'):
    plt.subplots(ncols=1, nrows=1, sharey=True, figsize=(3.54 / 1.5, 3.54 / 2))
    componentIdx = list(range(15))                 # zero-based column indices
    componentNum = [k + 1 for k in componentIdx]   # one-based x positions
    margVar = explVar['margVar'][0][0]
    barSpecs = (
        ('#1f77b4', 'Target'),
        ('#d62728', '$P_{com}$'),
        ('#7f7f7f', 'Time'),
    )
    # NOTE(review): no `bottom=` is given, so the three series are drawn on
    # top of each other at the same x rather than stacked - confirm intended.
    for row, (barColor, barLabel) in enumerate(barSpecs):
        plt.bar(componentNum, margVar[row, componentIdx], color=barColor, label=barLabel)
    plt.xlim(left=0.4)
    plt.xticks(list(range(1, 16, 2)), fontsize=8)  # odd component numbers 1..15
    plt.yticks(fontsize=8)
    plt.ylabel('Component variance (%)', fontsize=8)
    plt.xlabel('Component', fontsize=8)
    # plt.legend(bbox_to_anchor=[.095,.5],fontsize = 8)
    fileName = 'fig4C_dPCAExpVar' + 'Premotor' + '.pdf'
    plt.tight_layout()
    # plt.savefig(fileName,dpi = 600)
plt.show()
#%%fig4D
# 3-D trajectories of two projections (rows 6 and 7 of Zfull) against time.
compsPMC = [1, 2, 7, 4, 8, 11, 3, 5, 6]    # PMC PCom-3,5,6
compsArea5 = [1, 2, 7, 3, 8, 10, 4, 5, 6]  # all neurons Area5 PCom-4,5,6
Zfull = copy.deepcopy(ZfullArea5)
comps = copy.deepcopy(compsArea5)  # not referenced by the plotting code below
alphas = np.linspace(0.1, 1, 22)   # not referenced by the plotting code below

cond = ['VP', 'High', 'Low', 'P']
timeBins = [i*.1-.8 for i in range(22)]  # 22 bins, 0.1 apart, starting at -0.8

with plt.style.context('style_paper.mplstyle'):
    fig = plt.figure(figsize=(3.54 / 1.5, 3.54 / 1.5))
    ax = fig.add_subplot(111, projection='3d')
    colors = ['#CA2521', '#cc5e45', '#9eaee5', '#456ACF']
    # colors = ['#CA2521', 'k','gray','#456ACF']
    # (bin index, marker color); the original's commented-out labels call
    # these 'Disparity onset' (bin 3) and 'Target onset' (bin 8)
    eventBins = ((3, 'gray'), (8, 'k'))
    for j in range(4):          # one color/legend entry per condition in `cond`
        for i in range(5):      # second axis of Zfull
            pc1 = Zfull[6, i, j, :]
            pc2 = Zfull[7, i, j, :]
            firstTrace = (i == 0)
            if firstTrace and j == 0:
                # markers drawn once before the very first trace
                for binIdx, markerColor in eventBins:
                    ax.scatter(pc1[binIdx], pc2[binIdx], timeBins[binIdx], color=markerColor)
            # only the first trace of each condition carries the legend label
            lineKwargs = {'color': colors[j]}
            if firstTrace:
                lineKwargs['label'] = cond[j]
            ax.plot(pc1, pc2, timeBins, **lineKwargs)
            # event markers on every trace
            for binIdx, markerColor in eventBins:
                ax.scatter(pc1[binIdx], pc2[binIdx], timeBins[binIdx], color=markerColor)
    ax.set_xlabel('$P_{com}$ PC1', labelpad=0)
    ax.set_ylabel('$P_{com}$ PC2', labelpad=0)
    ax.set_zlabel('Time (s)', labelpad=0)

    ax.view_init(azim=-77, elev=171)
    plt.legend()
plt.show()

#%% 
# Same trajectories as a 3-D plot, this time with time on the x axis and the
# two projections (rows 6 and 7 of Zfull) on y and z.
compsPMC = [1, 2, 7, 4, 8, 11, 3, 5, 6]    # [1,2,7,5,8,10,3,4,6] # PMC PCom-3,5,6
compsArea5 = [1, 2, 7, 3, 8, 10, 4, 5, 6]  # all neurons Area5 PCom-4,5,6
Zfull = copy.deepcopy(ZfullArea5)
comps = copy.deepcopy(compsArea5)  # not referenced by the plotting code below
alphas = np.linspace(0.1, 1, 22)   # not referenced by the plotting code below

cond = ['VP', 'High', 'Low', 'P']
timeBins = [i*.1-.8 for i in range(22)]  # 22 bins, 0.1 apart, starting at -0.8

with plt.style.context('style_paper.mplstyle'):
    fig = plt.figure(figsize=(3.54 / 1.5, 3.54 / 1.5))
    ax = fig.add_subplot(111, projection='3d')
    colors = ['#CA2521', '#cc5e45', '#9eaee5', '#456ACF']
    # colors = ['#CA2521', 'k','gray','#456ACF']
    # (bin index, marker color); the original's commented-out labels call
    # these 'Disparity onset' (bin 3) and 'Target onset' (bin 8)
    eventBins = ((3, 'gray'), (8, 'k'))
    for j in range(4):          # one color per condition in `cond`
        for i in range(5):      # second axis of Zfull
            pc1 = Zfull[6, i, j, :]
            pc2 = Zfull[7, i, j, :]
            firstTrace = (i == 0)
            if firstTrace and j == 0:
                # markers drawn once before the very first trace
                for binIdx, markerColor in eventBins:
                    ax.scatter(timeBins[binIdx], pc1[binIdx], pc2[binIdx], color=markerColor)
            # only the first trace of each condition carries a label
            lineKwargs = {'color': colors[j]}
            if firstTrace:
                lineKwargs['label'] = cond[j]
            ax.plot(timeBins, pc1, pc2, **lineKwargs)
            # event markers on every trace
            for binIdx, markerColor in eventBins:
                ax.scatter(timeBins[binIdx], pc1[binIdx], pc2[binIdx], color=markerColor)
    ax.set_ylabel('$P_{com}$ PC1', labelpad=0, fontsize=8)
    ax.set_zlabel('$P_{com}$ PC2', labelpad=0, fontsize=8)
    ax.set_xlabel('Time (s)', labelpad=0, fontsize=8)
    plt.xticks(fontsize=8)
    plt.yticks(fontsize=8)
    # ax.set_zticks(fontsize = 8)
    ax.tick_params(axis='z', which='major', labelsize=8)
    ax.view_init(azim=-77, elev=171)
    # plt.legend()
    plt.tight_layout()
plt.show()

#%%fig4E
# Decoding accuracy over time for both areas, with 95% confidence bands.
# Rows of the meanConfInte* arrays: 0 = mean, 1 = lower bound, 2 = upper bound.
meanConfIntePMC = np.full((3, 22), np.nan)
meanConfInteArea5 = np.full((3, 22), np.nan)
inputData = copy.deepcopy(corrRateRelaDrifHandRaw)  #corrRateRelaDrifHandNormalizerC01

accPMC = inputData['corrRateScoreBinSmooPMC']['corrRateScoreBinSmooPMC']
accArea5 = inputData['corrRateScoreBinSmooArea5']['corrRateScoreBinSmooArea5']
repeats = range(50)  # first 50 rows (presumably repetitions - TODO confirm)

# per-bin mean and confidence interval across the 50 rows
for i in range(22):
    data = accPMC[repeats, i]
    meanConfIntePMC[:, i] = mean_confidence_interval(data, confidence=0.95)
    data = accArea5[repeats, i]
    meanConfInteArea5[:, i] = mean_confidence_interval(data, confidence=0.95)

timeAxis = [i*.1-.8 for i in range(22)]  # 22 bins, 0.1 apart, starting at -0.8

with plt.style.context('style_paper.mplstyle'):
    colors = plt.rcParams["axes.prop_cycle"].by_key()["color"][3:]
    f, ax1 = plt.subplots(ncols=1, nrows=1, sharex=True, figsize=[7.25 / 3, 3.54 / 2])
    # mean curve followed by its confidence band, one pair per area
    areaSeries = (
        (accPMC, meanConfIntePMC, colors[0], 'Premotor'),
        (accArea5, meanConfInteArea5, colors[1], 'Parietal'),
    )
    for acc, confInte, areaColor, areaLabel in areaSeries:
        ax1.plot(timeAxis, np.nanmean(acc[repeats, :], axis=0), color=areaColor, label=areaLabel)
        ax1.fill_between(timeAxis, confInte[1, :], confInte[2, :],
                         edgecolor=areaColor, facecolor=areaColor, alpha=0.4)
    # horizontal bars near the top, from bin 7 (Premotor) and bin 10 (Parietal)
    # to the last bin (presumably significance markers - TODO confirm)
    ax1.plot(timeAxis[7:], [.92] * 15, '-', color=colors[0], linewidth=1)  #7
    ax1.plot(timeAxis[10:], [.90] * 12, '-', color=colors[1], linewidth=1)  #11
    # dashed reference line at 0.5
    ax1.plot(timeAxis, [.5] * 22, color='k', ls='--')
    ax1.set_yticks(np.arange(.45, .9, step=0.1))
    ax1.set_ylim([0.45, .94])
    ax1.set_ylabel('Decoding accuracy')
    # gray band over bins 3..8, spanning the full y range
    ax1.fill_between(timeAxis[3:9], [.45] * 6, [.94] * 6, edgecolor=[], facecolor='gray', alpha=0.4)
    ax1.set_xlabel('Time from target onset (s)')
    plt.legend(loc='upper left', ncol=1, bbox_to_anchor=[-.04, 1])
    ax1.set_xticks(np.arange(-1, 1.51, step=0.5))
    ax1.set_xlim([-.9, 1.4])
    plt.tight_layout()
    fileName = 'fig4E_rawDecodingPcom.pdf'
    # plt.savefig(fileName,dpi = 600)
plt.show()
#%%fig4F
# Heatmap of the jPECC matrix with white guide lines.
colors = plt.cm.Set2.colors  # not referenced by the plotting code below
# bin labels from -0.8 to 1.3 in steps of 0.1, and the same list reversed
ll = [-0.8,-0.7,-0.6,-0.5,-0.4,-0.3,-0.2,-0.1,0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1,1.1,1.2,1.3]
ll2 = ll[::-1]

with plt.style.context('style_paper.mplstyle'):
    # sns.set(style = "ticks")
    f, ax1 = plt.subplots(ncols=1, nrows=1, sharey=True, figsize=[3.54 / 1.2, 3.54 / 2])
    g1 = sns.heatmap(
        jPECCPMC2Area5,
        yticklabels=ll2,
        vmin=-.1, vmax=.65,
        xticklabels=ll,
        # cbar=False,
        cmap='viridis',
        square=True,
        ax=ax1,
    )
    gridPts = list(range(23))
    # white anti-diagonal from (0, 22) to (22, 0)
    ax1.plot(gridPts, [22 - k for k in gridPts], color='w')
    # white horizontal line at y = 14 and vertical line at x = 8, the
    # positions that are labelled '0' by the tick settings below
    ax1.plot(gridPts, [14] * 23, color='w')
    ax1.plot([8] * 23, gridPts, color='w')
    g1.set_xlabel('Parietal')
    g1.set_ylabel('Premotor')
    # replace the per-bin tick labels with a sparse set
    ax1.set_xticks([3, 8, 13, 18])
    ax1.set_xticklabels(['-0.5', '0', '0.5', '1'], rotation=0)
    ax1.set_yticks([4, 9, 14, 19])

    ax1.set_yticklabels(['1', '0.5', '0', '-0.5'], rotation=0)
    plt.tight_layout()
    fileName = 'fig4F_jPECCMatr.pdf'
    # plt.savefig(fileName,dpi = 600)
#%%
# Line plot of the jPECC asymmetry index over time.
timeBins = [.1*i-.5 for i in range(22-3)]  # 19 bins, 0.1 apart, starting at -0.5
with plt.style.context('style_paper.mplstyle'):
    f, ax1 = plt.subplots(ncols=1, nrows=1, sharey=True, figsize=[3.54 / 1.8, 3.54 / 2])
    # the index is plotted against the first 18 bins
    ax1.plot(timeBins[0:18], asymJPECCIndePMC2Area5, label='aligned', color='k')  #,color = '#1f77b4')
    # horizontal bar at 0.3 over bins 9..13 (presumably a significance
    # marker - TODO confirm)
    ax1.plot(timeBins[9:14], [0.3] * 5, '-', color='k')  #,color = '#1f77b4')
    # ax1.plot(timeBins[0:18],[0 for j in range(0,18)],'--',color = 'k')#,color = '#1f77b4')
    plt.xlabel('Time from target onset (s)')
    plt.ylabel('asymmetric index')  #\n(-0.3 to +0.3s)')
    plt.xticks(np.arange(-1, 1.1, .5))
    plt.xlim(-.5, 1.3)
    plt.yticks([0.00, 0.25])
    plt.ylim([-.05, .31])
    fileName = 'fig4F_jPECCAsymInde.pdf'
    # plt.savefig(fileName,dpi = 600)
plt.show()



